# -*- coding:utf-8 -*-

import codecs
import os

from utils.data_utils import features_labels_split
from dto.devide_dataset_dto import DividedDataSetDto
from config.glob.global_pool import global_pool

"""
Load plain-text data.
"""


def read(ds_dir):
    """
    Load the raw poetry text and wrap it in a DividedDataSetDto.

    The file ``poetry.txt`` inside ``ds_dir`` is read as UTF-8, converted by
    ``global_pool.embedding.text_to_arr`` and then passed to
    ``features_labels_split`` together with ``global_pool.config.xs_shape[0]``.

    :param ds_dir: directory that contains ``poetry.txt``; a trailing path
        separator is optional.
    :return: a ``DividedDataSetDto`` built from the first three items
        returned by ``features_labels_split``.
    :raises IOError: if ``poetry.txt`` cannot be opened.
    :raises UnicodeDecodeError: if the file is not valid UTF-8.
    """
    # os.path.join gives the same path as plain concatenation when ds_dir
    # already ends with a separator, and still works when it does not.
    poetry_path = os.path.join(ds_dir, 'poetry.txt')
    with codecs.open(poetry_path, encoding='utf-8') as f:
        text = f.read()
    data_pre = global_pool.embedding.text_to_arr(text)
    # Separate the converted text into features and labels.
    train_set = features_labels_split(data_pre, global_pool.config.xs_shape[0])
    # NOTE(review): no separate train/validation split is performed here; the
    # three DTO arguments are taken positionally from features_labels_split's
    # result -- confirm their meaning against that helper.
    return DividedDataSetDto(train_set[0], train_set[1], train_set[2])
